
Conversation

@RalphMao RalphMao (Collaborator) commented Sep 25, 2025

What does this PR do?

Type of change: New feature and example

Overview: vLLM linear and MoE layer quantization support is already available in the ModelOpt library. This example shows how to calibrate and serve a fake-quantized model with vLLM. For more detailed instructions, see the README.

Usage

python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
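
Once the server is up it exposes the standard OpenAI-compatible HTTP API, so a quick smoke test can be run from Python (the prompt is arbitrary, and the model name/port are placeholders that should match the launch command above):

import requests

# Minimal smoke test against the fake-quantized server started above.
resp = requests.post(
    "http://0.0.0.0:8000/v1/completions",
    json={
        "model": "<model_path>",  # same path/name passed on the command line
        "prompt": "The capital of France is",
        "max_tokens": 16,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])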

Testing

This example is tested with the latest vLLM (0.10.2), using Qwen3 MoE/dense models and the Llama 3.1 series.

Tested with GSM8K.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Serve fake-quantized models with vLLM: calibration, optional amax merging/loading, and weight folding.
    • CLI utility to convert Hugging Face amax values into vLLM format.
  • Documentation

    • New guide for Docker setup, calibration, serving, curl-based testing, optional evaluation, and notes for QAT/PTQ workflows.
  • Chores

    • Example Docker image and environment setup for the vLLM serve example; updated changelog and codeowners entry.

coderabbitai bot commented Sep 25, 2025

Walkthrough

Adds a new vLLM serve example: Dockerfile, README, an HF→vLLM amax conversion script, and a fakequant server script that calibrates, optionally loads amax, folds weights, and monkey-patches vLLM Worker methods before starting the server.

Changes

  • Containerization for vLLM serve (examples/vllm_serve/Dockerfile): New Dockerfile based on vllm/vllm-openai:v0.10.2; sets PIP options and the /workspace WORKDIR, installs system deps, copies the local TensorRT-Model-Optimizer (removing its .git), installs it in editable mode with all extras and dev-test, installs flash-attn==2.7.4.post1, pre-compiles CUDA extensions, installs example requirements, makes /workspace writable, clears ENTRYPOINT, and sets the default CMD to /bin/bash.
  • Documentation (examples/vllm_serve/README.md): New README describing Docker setup, calibration/serving/testing workflows, CLI examples for running the vLLM server, optional lm_eval evaluation, and guidance for merging/loading amax into quant_config for QAT/PTQ.
  • AMAX conversion utility (examples/vllm_serve/convert_amax_hf2vllm.py): New CLI script adding convert_amax_hf2vllm(hf_state_dict) to merge HF amax keys into vLLM format (q_proj/k_proj/v_proj → qkv_proj and gate_proj/up_proj → gate_up_proj by elementwise max), plus test_conversion() and main() for file I/O, dry-run, and diagnostics.
  • vLLM fakequant integration (examples/vllm_serve/vllm_serve_fakequant.py): New script adding the disable_compilation context manager, the fakequant_run_prolog calibration flow and calibrate_loop, quant_config, optional amax loading and validation, weight folding via mtq.fold_weight, and monkey-patches for Worker.determine_available_memory and Worker.compile_or_warm_up_model; includes main() to parse args and launch the patched server.
  • Repo metadata & changelog (.github/CODEOWNERS, CHANGELOG.rst): Adds a CODEOWNERS entry for /examples/vllm_serve; updates CHANGELOG.rst with a 0.39 New Features line referencing PTQ/fakequant support and the new example.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User as User
  participant CLI as vllm_serve_fakequant.py
  participant Server as vLLM Server
  participant Worker as vLLM Worker
  participant Model as Model Runner
  participant Q as Quantization (mtq)

  User->>CLI: launch with args
  CLI->>Server: start_server()
  Server->>Worker: compile_or_warm_up_model()
  note over Worker: patched new_compile_or_warm_up_model()

  rect rgba(200,230,255,0.18)
    Worker->>Worker: fakequant_run_prolog()
    Worker->>Model: load tokenizer & prepare calib data
    Worker->>Q: run calibrate_loop (inside disable_compilation)
    alt amax provided
      Worker->>Q: load amax state_dict and validate
    end
    Q->>Q: fold_weight()
  end

  Worker->>Worker: continue original compile/warmup
  Server-->>User: Ready to serve
  User->>Server: inference request
  Server->>Worker: execute_model()
  Worker->>Model: forward (with folded weights)
  Model-->>Server: outputs
  Server-->>User: responses

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

Poem

I hop through Docker lanes with glee,
Merging amax to qkv and me.
I calibrate, fold, then serve—so spry,
Kernels hum and latencies fly.
Carrots crunchy, models ready—bye-bye! 🥕✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title concisely summarizes the primary enhancement (fakequant serving support with the latest vLLM and generalized calibration logic), reflects the main changes in the pull request, and is immediately understandable to collaborators.
  • Docstring Coverage: ✅ Passed. No functions found in the changes; docstring coverage check skipped.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d02bc95 and cfd61b4.

📒 Files selected for processing (1)
  • .github/CODEOWNERS (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • .github/CODEOWNERS
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 9

🧹 Nitpick comments (8)
.pre-commit-config.yaml (1)

35-39: Reverting clang-format to v9 may cause formatting drift; confirm toolchain compatibility.

v9.0.0 is quite old and lacks many later fixes. Please verify CI/dev environments and any IDE integrations rely on the same version to avoid churn. If this is for compatibility with an external toolchain, document the rationale here.

Would you like me to scan the repo for references to clang-format versions in docs/CI to confirm alignment?

examples/vllm_serve/Dockerfile (2)

32-35: Robust loop over requirements; avoid backslash escapes.

Use read -r to avoid backslash interpretation and set IFS.

-RUN find TensorRT-Model-Optimizer/examples -name "requirements.txt" | grep -v "windows" | while read req_file; do \
+RUN find TensorRT-Model-Optimizer/examples -name "requirements.txt" | grep -v "windows" | while IFS= read -r req_file; do \
         echo "Installing from $req_file"; \
         pip install -r "$req_file" || echo "Warning: Failed to install from $req_file"; \
     done

16-23: Consider .dockerignore instead of copying and pruning .git.

COPY . brings in large/unnecessary files; .dockerignore reduces context and speeds builds.

examples/vllm_serve/convert_amax_hf2vllm.py (1)

155-163: CLI UX: require input/output unless --test, and add --only-amax option.

Add a flag to save only amax keys to keep checkpoints small and avoid later confusion.

 parser = argparse.ArgumentParser(
     description="Convert amax values from HuggingFace to vLLM format"
 )
@@
 parser.add_argument("--dry-run", action="store_true", help="Show conversion without saving")
 parser.add_argument("--test", action="store_true", help="Run test with sample data")
+parser.add_argument("--only-amax", action="store_true", help="Save only *_amax keys to output")
examples/vllm_serve/vllm_serve_fakequant.py (4)

117-124: Consider left padding for calibration consistency.

Left padding often yields better calibration; dataset_utils also warns.

-    calib_dataloader = get_dataset_dataloader(
+    tokenizer.padding_side = "left"
+    calib_dataloader = get_dataset_dataloader(

223-227: determine_available_memory: return value not used; simplify.

The original returns None. Returning results is harmless, but you can just call and return None to match signature.

-def new_determine_available_memory(self) -> None:
-    with disable_compilation(self.model_runner.model):
-        results = old_determine_available_memory(self)
-    return results
+def new_determine_available_memory(self) -> None:
+    with disable_compilation(self.model_runner.model):
+        old_determine_available_memory(self)

229-233: Prolog trigger OK; confirm idempotency.

compile_or_warm_up_model may be called more than once. Ensure fakequant_run_prolog is idempotent or guarded.
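
A minimal sketch of such a guard (the `_fakequant_prolog_done` flag and the `old_compile_or_warm_up_model` alias are illustrative assumptions, mirroring the monkey-patch pattern used for the other Worker method):

def new_compile_or_warm_up_model(self):
    # Run the calibration prolog at most once per worker, even if warm-up is re-entered.
    if not getattr(self, "_fakequant_prolog_done", False):
        fakequant_run_prolog(self)
        self._fakequant_prolog_done = True
    old_compile_or_warm_up_model(self)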


1-52: Duplicate/stacked license headers; consolidate to avoid confusion.

The file contains multiple Apache-2.0 and MIT blocks. Keep a single project header and add third-party notices in a dedicated section or 3rd-party notices file.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4ff8fc9 and c03be05.

📒 Files selected for processing (5)
  • .pre-commit-config.yaml (1 hunks)
  • examples/vllm_serve/Dockerfile (1 hunks)
  • examples/vllm_serve/README.md (1 hunks)
  • examples/vllm_serve/convert_amax_hf2vllm.py (1 hunks)
  • examples/vllm_serve/vllm_serve_fakequant.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/vllm_serve/vllm_serve_fakequant.py (3)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (157-232)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
examples/vllm_serve/convert_amax_hf2vllm.py (1)
  • main (155-209)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (7)
examples/vllm_serve/README.md (1)

24-25: Confirm CLI flags validity across vLLM versions.

The short flag -tp may differ across versions. Ensure make_arg_parser exposes -tp; otherwise use --tensor-parallel-size.

examples/vllm_serve/Dockerfile (1)

26-27: Flash-Attn version compatibility.

Pinning flash-attn to 2.7.4.post1 may conflict with the base image CUDA/PyTorch versions. Please confirm ABI/Wheel availability for the base image.

examples/vllm_serve/convert_amax_hf2vllm.py (2)

71-85: Merging logic OK; consider dtype/device preservation and scalar handling.

torch.stack + max preserves dtype/device; fine. Ensure q/k/v shapes match; else raise with a clear error.


191-201: Dry-run output helpful.

Good observability of key changes.

examples/vllm_serve/vllm_serve_fakequant.py (3)

171-175: PP intermediate tensor handoff may differ across vLLM versions.

recv_tensor_dict signature and gating on is_first_rank can vary (0.10 vs 0.11). Please validate on both targets.


183-185: Rank-0 printing only: good.

Prevents noisy logs in distributed runs.


241-251: Server startup path likely OK.

FlexibleArgumentParser + make_arg_parser + uvloop.run(run_server(args)) matches vLLM patterns.

Comment on lines +206 to +209
print(f"Saving vLLM checkpoint to: {args.output}")
os.makedirs(os.path.dirname(args.output), exist_ok=True)
torch.save(vllm_state_dict, args.output)
print("Conversion complete!")

⚠️ Potential issue

Bug: os.makedirs with empty dirname fails when output is a filename.

Guard dirname before creating.

-print(f"Saving vLLM checkpoint to: {args.output}")
-os.makedirs(os.path.dirname(args.output), exist_ok=True)
-torch.save(vllm_state_dict, args.output)
+print(f"Saving vLLM checkpoint to: {args.output}")
+out_dir = os.path.dirname(args.output)
+if out_dir:
+    os.makedirs(out_dir, exist_ok=True)
+to_save = vllm_state_dict
+if getattr(args, "only-amax", False):
+    to_save = {k: v for k, v in vllm_state_dict.items() if k.endswith("_amax")}
+torch.save(to_save, args.output)
🤖 Prompt for AI Agents
In examples/vllm_serve/convert_amax_hf2vllm.py around lines 206 to 209,
os.makedirs(os.path.dirname(args.output), exist_ok=True) will raise if
os.path.dirname(args.output) is empty (when args.output is just a filename);
guard the dirname before creating directories by computing d =
os.path.dirname(args.output) and only call os.makedirs(d, exist_ok=True) if d is
non-empty (or truthy), otherwise skip mkdir and proceed to torch.save.

@@ -0,0 +1,44 @@
FROM vllm/vllm-openai:v0.10.2

⚠️ Potential issue

Align base image with tested/targeted vLLM version.

Base is v0.10.2; README claims tests with 0.9.0 and 0.11.2. Pick one target (prefer latest tested) and update image, or document the intended version matrix.

-FROM vllm/vllm-openai:v0.10.2
+# Consider aligning with the tested version, e.g.:
+# FROM vllm/vllm-openai:v0.11.2
+FROM vllm/vllm-openai:v0.11.2
🤖 Prompt for AI Agents
In examples/vllm_serve/Dockerfile around line 1, the base image tag (v0.10.2)
does not match the README's tested versions (0.9.0 and 0.11.2); update the
Dockerfile to use the intended target image (e.g., FROM
vllm/vllm-openai:v0.11.2) to align with the latest tested version, or
alternatively add a comment/README entry documenting the supported vLLM version
matrix and rationale for keeping v0.10.2.

done

# Allow users to run without root
RUN chmod -R 777 /workspace

⚠️ Potential issue

Avoid 777; prefer least privilege.

chmod -R 777 is unsafe. Grant group write as needed or create a non-root user.

-RUN chmod -R 777 /workspace
+# Example: create non-root user and set appropriate permissions
+RUN useradd -m -u 1000 appuser \
+ && chown -R appuser:appuser /workspace
+USER appuser
🤖 Prompt for AI Agents
In examples/vllm_serve/Dockerfile around line 38, replace the unsafe "chmod -R
777 /workspace" with a least-privilege approach: create a non-root user and
group, chown the workspace to that user/group, and set restrictive perms (e.g.,
755 for dirs and 644 for files or 775/g+w only where group write is required)
instead of global 777; then switch to that USER with a USER directive so the
container runs without root privileges. Ensure recursive ownership is limited to
the workspace path and avoid granting execute/write bits to all users.


Compared with realquant, fakequant is 2-5x slower, but doesn't require dedicated kernel support and facilitates research.

This example is tested with vllm 0.9.0 and 0.11.2

⚠️ Potential issue

Version mismatch with Docker base image.

README states testing with vLLM 0.9.0 and 0.11.2, but the Dockerfile uses v0.10.2. Align the example and base image or document supported versions explicitly.

🤖 Prompt for AI Agents
In examples/vllm_serve/README.md around line 7, the README claims testing with
vLLM 0.9.0 and 0.11.2 while the Dockerfile uses v0.10.2; update either the
README or the Dockerfile so versions align or explicitly list all supported vLLM
versions. Fix by choosing one source of truth: (a) change the README to state
the Dockerfile's v0.10.2 and note any other supported versions, or (b) update
the Dockerfile to one of the versions listed in the README; also add a short
note in the README explaining the tested and supported vLLM versions and any
compatibility caveats.

Comment on lines +110 to +112
if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token


⚠️ Potential issue

Fix pad_token condition.

Set pad_token when missing or equals unk; current logic is inverted.

-    if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token is None or tokenizer.pad_token == getattr(tokenizer, "unk_token", "<unk>"):
+        tokenizer.pad_token = tokenizer.eos_token
🤖 Prompt for AI Agents
In examples/vllm_serve/vllm_serve_fakequant.py around lines 110 to 112, the
conditional that sets tokenizer.pad_token is inverted; change it to set
pad_token when it is missing or equals "<unk>" (i.e., if tokenizer.pad_token is
None or tokenizer.pad_token == "<unk>": set tokenizer.pad_token =
tokenizer.eos_token) so the pad token is replaced only when absent or explicitly
"<unk>".

Comment on lines +177 to +181
quant_cfg = getattr(mtq, quant_config["quant_format"])

with disable_compilation(self.model):
mtq.quantize(self.model, quant_cfg, forward_loop=calibrate_loop)


⚠️ Potential issue

Handle invalid quant_format gracefully.

Guard getattr and raise a clear error if the format name is unknown.

-    quant_cfg = getattr(mtq, quant_config["quant_format"])
+    try:
+        quant_cfg = getattr(mtq, quant_config["quant_format"])
+    except AttributeError as e:
+        raise ValueError(f"Unknown quant_format: {quant_config['quant_format']}") from e
🤖 Prompt for AI Agents
In examples/vllm_serve/vllm_serve_fakequant.py around lines 177 to 181, the code
uses getattr(mtq, quant_config["quant_format"]) without validating the attribute
and will raise an unclear AttributeError for unknown formats; update this to
check for the attribute (e.g., using hasattr or try/except AttributeError) and
if missing raise a clear ValueError (or RuntimeError) that includes the invalid
quant_format and optionally lists supported formats, before proceeding to call
mtq.quantize.

Collaborator

Please add this directory to https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/.github/CODEOWNERS and use one of the existing reviewer teams or create a new one.

Collaborator

Please add this new example to the Changelog under a new 0.39 section in https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst.

@mxinO mxinO left a comment

LGTM! Thanks.

if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

if quant_config["amax_file_path"]:

Can we check for an `amax.pt` file inside the model folder if the amax path is not specified?

Collaborator Author

I would rather not do this, as silently loading a checkpoint seems fishy to me. Loading a checkpoint should be explicitly specified.

@RalphMao RalphMao force-pushed the huizim/vllm_serve_update branch from c03be05 to 6d69e05 Compare October 1, 2025 17:53
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (9)
examples/vllm_serve/Dockerfile (2)

1-1: Version mismatch with README.

The Dockerfile uses v0.10.2 but the README (line 7) states testing with vLLM 0.9.0 and 0.11.2. Align the versions or document the supported version matrix.


38-38: Security issue: overly permissive permissions.

chmod -R 777 grants all users read/write/execute on /workspace, which is unsafe. Create a non-root user and set restrictive permissions instead.

Apply this diff to improve security:

-RUN chmod -R 777 /workspace
+RUN useradd -m -u 1000 appuser \
+ && chown -R appuser:appuser /workspace
+USER appuser
examples/vllm_serve/convert_amax_hf2vllm.py (1)

206-209: Bug: directory creation fails when output is a bare filename.

os.path.dirname(args.output) returns an empty string for bare filenames, causing os.makedirs to fail or misbehave. Guard against empty dirname.

Apply this diff:

 print(f"Saving vLLM checkpoint to: {args.output}")
-os.makedirs(os.path.dirname(args.output), exist_ok=True)
+out_dir = os.path.dirname(args.output)
+if out_dir:
+    os.makedirs(out_dir, exist_ok=True)
 torch.save(vllm_state_dict, args.output)
examples/vllm_serve/vllm_serve_fakequant.py (4)

75-95: Raises on unknown model types.

disable_compilation throws ValueError if the model lacks both model and language_model attributes. Make it tolerant by falling back to the object itself or no-op.

Apply this diff:

 @contextmanager
-def disable_compilation(model):
-    """Context manager to temporarily disable torch.compile"""
-    do_not_compile = True
-    if hasattr(model, "model"):
-        do_not_compile = model.model.do_not_compile
-        model.model.do_not_compile = True
-    elif hasattr(model, "language_model"):  # VLM requires this
-        do_not_compile = model.language_model.model.do_not_compile
-        model.language_model.model.do_not_compile = True
-    else:
-        raise ValueError("Model does not have a model or language_model attribute")
-
+def disable_compilation(obj):
+    """Temporarily set do_not_compile on the underlying model if available."""
+    target = None
+    if hasattr(obj, "model"):
+        target = obj.model
+    elif hasattr(obj, "language_model") and hasattr(obj.language_model, "model"):
+        target = obj.language_model.model
+    else:
+        target = obj
+    old = getattr(target, "do_not_compile", False)
     try:
+        setattr(target, "do_not_compile", True)
         yield
     finally:
-        if hasattr(model, "model"):
-            model.model.do_not_compile = do_not_compile
-        elif hasattr(model, "language_model"):
-            model.language_model.model.do_not_compile = do_not_compile
+        setattr(target, "do_not_compile", old)

110-112: Bug: pad_token condition is inverted.

The current logic sets pad_token when it is NOT "<unk>" OR is None, which is incorrect. It should set pad_token when it IS None OR equals "<unk>".

Apply this diff:

-    if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
+    if tokenizer.pad_token is None or tokenizer.pad_token == "<unk>":
         tokenizer.pad_token = tokenizer.eos_token

177-181: Missing validation for quant_format.

getattr(mtq, quant_config["quant_format"]) will raise an unclear AttributeError if the format is invalid. Validate and provide a clear error message.

Apply this diff:

-    quant_cfg = getattr(mtq, quant_config["quant_format"])
+    try:
+        quant_cfg = getattr(mtq, quant_config["quant_format"])
+    except AttributeError as e:
+        raise ValueError(
+            f"Unknown quant_format: {quant_config['quant_format']}. "
+            f"Check modelopt.torch.quantization for valid formats."
+        ) from e

187-213: Weak validation: count-only check doesn't verify key mapping.

Checking only amax key counts can miss mismatches when keys differ but counts match. Filter to _amax keys, verify exact key correspondence, and report missing/extra keys.

Apply this diff:

     amax_file_path = quant_config["amax_file_path"]
     if amax_file_path:
         print(f"Loading amax values from {amax_file_path}")
         saved_amax_dict = torch.load(amax_file_path, map_location=self.device)
         current_state_dict = self.model.state_dict()
 
-        # Count amax keys in checkpoint and model
-        checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("amax")]
-        model_amax_keys = [key for key in current_state_dict if key.endswith("amax")]
+        # Filter to amax keys only
+        saved_amax = {k: v for k, v in saved_amax_dict.items() if "_amax" in k}
+        model_amax_keys = {k for k in current_state_dict if "_amax" in k}
 
-        checkpoint_amax_count = len(checkpoint_amax_keys)
-        model_amax_count = len(model_amax_keys)
-
-        # Ensure counts match
-        if checkpoint_amax_count != model_amax_count:
-            raise ValueError(
-                f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} "
-                f"amax keys but model has {model_amax_count} amax keys. "
-            )
+        missing_in_model = set(saved_amax.keys()) - model_amax_keys
+        extra_in_model = model_amax_keys - set(saved_amax.keys())
+        
+        if missing_in_model or extra_in_model:
+            error_msg = "Amax key mismatch:\n"
+            if missing_in_model:
+                error_msg += f"  Keys in file but not in model: {sorted(list(missing_in_model))[:5]}\n"
+            if extra_in_model:
+                error_msg += f"  Keys in model but not in file: {sorted(list(extra_in_model))[:5]}\n"
+            raise ValueError(error_msg)
 
-        # Update amax values
-        for key, value in saved_amax_dict.items():
-            if key in current_state_dict:
-                current_state_dict[key] = value.to(self.device)
+        # Update amax values only
+        with torch.no_grad():
+            for key, value in saved_amax.items():
+                current_state_dict[key] = value.to(self.device)
 
         self.model.load_state_dict(current_state_dict, strict=True)
examples/vllm_serve/README.md (2)

7-7: Version mismatch with Dockerfile.

The README claims testing with vLLM 0.9.0 and 0.11.2, but the Dockerfile uses v0.10.2. Align these or explicitly document the version matrix.


19-19: Filename typo.

The referenced filename vllm_serve_fake_quant.py has an extra underscore. The actual file is vllm_serve_fakequant.py.

Apply this diff:

-Step 1: Modify `quant_config` in `vllm_serve_fake_quant.py` for the desired quantization format
+Step 1: Modify `quant_config` in `vllm_serve_fakequant.py` for the desired quantization format
🧹 Nitpick comments (3)
examples/vllm_serve/Dockerfile (1)

29-29: Document why CUDA extension compilation failures are acceptable.

The || true swallows compilation errors. If precompilation is optional, add a brief comment explaining why failures are tolerated.

-RUN python3 -c "import modelopt.torch.quantization.extensions as ext; ext.precompile()" || true
+# Pre-compile CUDA extensions (optional; compilation may fail in some environments)
+RUN python3 -c "import modelopt.torch.quantization.extensions as ext; ext.precompile()" || true
examples/vllm_serve/vllm_serve_fakequant.py (1)

113-116: Consider auto-detecting amax.pt in model directory.

As suggested in past review, checking for amax.pt in the model folder when amax_file_path is None would improve user experience.

Apply this diff:

+    # Auto-detect amax file if not specified
+    if not quant_config["amax_file_path"]:
+        default_amax = os.path.join(self.model_config.model, "amax.pt")
+        if os.path.isfile(default_amax):
+            quant_config["amax_file_path"] = default_amax
+            print(f"Auto-detected amax file at {default_amax}")
+    
     if quant_config["amax_file_path"]:
         # If amax file path is provided, we only need to do a simple calibration step
         quant_config["quant_num_samples"] = 1
examples/vllm_serve/README.md (1)

46-56: Section marked WIP but provides complete instructions.

The heading says "(WIP)" but the steps appear complete and functional. Either remove the WIP tag or clarify what remains unfinished.

-## Load QAT/PTQ model and serve in vLLM (WIP)
+## Load QAT/PTQ model and serve in vLLM

Or clarify what's incomplete:

-## Load QAT/PTQ model and serve in vLLM (WIP)
+## Load QAT/PTQ model and serve in vLLM (WIP: only tested for Llama 3.1)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c03be05 and 6d69e05.

📒 Files selected for processing (4)
  • examples/vllm_serve/Dockerfile (1 hunks)
  • examples/vllm_serve/README.md (1 hunks)
  • examples/vllm_serve/convert_amax_hf2vllm.py (1 hunks)
  • examples/vllm_serve/vllm_serve_fakequant.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/vllm_serve/vllm_serve_fakequant.py (2)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (157-232)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
examples/vllm_serve/convert_amax_hf2vllm.py (1)
examples/vllm_serve/vllm_serve_fakequant.py (1)
  • main (241-250)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
examples/vllm_serve/convert_amax_hf2vllm.py (2)

26-86: LGTM: Conversion logic is sound.

The function correctly merges QKV and gate/up projection amax values using elementwise max, which is appropriate for quantization ranges.


88-153: LGTM: Test coverage validates key transformations.

The test function provides clear validation of the merging logic with expected vs. actual key comparisons.

examples/vllm_serve/vllm_serve_fakequant.py (1)

125-176: LGTM: Calibration loop correctly integrates with vLLM's execution model.

The synthetic request construction and scheduler output assembly properly mirror vLLM's internal structures for calibration.

examples/vllm_serve/README.md (1)

1-57: LGTM: Clear documentation of the workflow.

The README provides a well-structured guide covering environment setup, calibration, serving, and evaluation with concrete examples.

@realAsma realAsma (Contributor) left a comment

Looks great!!

@RalphMao RalphMao force-pushed the huizim/vllm_serve_update branch from 6d69e05 to 39d7ded Compare October 3, 2025 22:20
@RalphMao RalphMao requested a review from a team as a code owner October 3, 2025 22:20
@RalphMao RalphMao force-pushed the huizim/vllm_serve_update branch from 39d7ded to 5f1ed83 Compare October 3, 2025 23:06
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
examples/vllm_serve/vllm_serve_fakequant.py (1)

187-213: Consider stricter amax key validation.

While the count-based check (lines 197-205) catches gross mismatches, it doesn't verify that checkpoint keys actually map to model keys. If the checkpoint contains different amax keys (e.g., from a different model architecture), the count could match but keys could be misaligned.

Consider validating key sets explicitly:

-    checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("amax")]
-    model_amax_keys = [key for key in current_state_dict if key.endswith("amax")]
-
-    checkpoint_amax_count = len(checkpoint_amax_keys)
-    model_amax_count = len(model_amax_keys)
-
-    # Ensure counts match
-    if checkpoint_amax_count != model_amax_count:
-        raise ValueError(
-            f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} "
-            f"amax keys but model has {model_amax_count} amax keys. "
-        )
+    # Filter to amax keys only
+    checkpoint_amax_keys = {k for k in saved_amax_dict if k.endswith("_amax") or k.endswith("amax")}
+    model_amax_keys = {k for k in current_state_dict if k.endswith("_amax") or k.endswith("amax")}
+
+    # Verify key sets match
+    missing_in_model = checkpoint_amax_keys - model_amax_keys
+    extra_in_model = model_amax_keys - checkpoint_amax_keys
+
+    if missing_in_model or extra_in_model:
+        raise ValueError(
+            f"Amax key mismatch:\n"
+            f"  Keys in checkpoint not found in model: {sorted(list(missing_in_model))[:5]}...\n"
+            f"  Keys in model not found in checkpoint: {sorted(list(extra_in_model))[:5]}..."
+        )
 
     # Update amax values
-    for key, value in saved_amax_dict.items():
-        if key in current_state_dict:
+    with torch.no_grad():
+        for key in checkpoint_amax_keys:
+            value = saved_amax_dict[key]
             current_state_dict[key] = value.to(self.device)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 39d7ded and 5f1ed83.

📒 Files selected for processing (5)
  • examples/vllm_serve/Dockerfile (1 hunks)
  • examples/vllm_serve/README.md (1 hunks)
  • examples/vllm_serve/convert_amax_hf2vllm.py (1 hunks)
  • examples/vllm_serve/vllm_serve_fakequant.py (1 hunks)
  • modelopt/torch/quantization/plugins/vllm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/vllm_serve/README.md
  • modelopt/torch/quantization/plugins/vllm.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/vllm_serve/vllm_serve_fakequant.py (3)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (157-232)
examples/llm_autodeploy/api_server.py (1)
  • run_server (194-208)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (8)
examples/vllm_serve/Dockerfile (2)

29-29: LGTM: Pre-compile with fallback.

Using || true allows the build to continue if CUDA extension pre-compilation fails, which is appropriate for environments where compilation might not be immediately possible.


32-35: LGTM: Requirements install with error handling.

The script iterates through requirements files, attempts installation, and logs warnings on failure without stopping the build. This is a pragmatic approach for optional dependencies.

examples/vllm_serve/convert_amax_hf2vllm.py (2)

26-85: LGTM: Conversion logic is sound.

The function correctly:

  • Preserves non-amax keys unchanged.
  • Groups q/k/v and gate/up amax keys using regex patterns.
  • Merges grouped keys by taking elementwise max via torch.stack().max(dim=0)[0].
  • Handles single-key groups by renaming.

The regex patterns and merging strategy align with the vLLM format requirements described in the PR.
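
For readers skimming the review, a stripped-down sketch of that merge step (the key names are illustrative, not the script's exact regex patterns):

import torch

def merge_amax(values: list[torch.Tensor]) -> torch.Tensor:
    # Elementwise max across per-projection amax tensors; shapes must match.
    return torch.stack(values).max(dim=0)[0]

# e.g. q/k/v input-quantizer amax values collapse into one qkv_proj amax
hf_amax = {
    "layers.0.self_attn.q_proj.input_quantizer._amax": torch.tensor(1.0),
    "layers.0.self_attn.k_proj.input_quantizer._amax": torch.tensor(0.8),
    "layers.0.self_attn.v_proj.input_quantizer._amax": torch.tensor(1.2),
}
qkv_amax = merge_amax(list(hf_amax.values()))  # tensor(1.2000)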


88-153: LGTM: Test function validates conversion.

The test creates a representative HF state dict with q/k/v and gate/up amax keys, runs the conversion, and verifies the output matches expected vLLM keys using set comparison. This provides a basic sanity check for the conversion logic.

examples/vllm_serve/vllm_serve_fakequant.py (4)

113-123: LGTM: Efficient calibration when amax is pre-computed.

Setting quant_num_samples = 1 when amax_file_path is provided optimizes calibration by performing only a minimal forward pass before loading pre-computed amax values. The dataloader setup correctly uses the modelopt utility function.


241-250: LGTM: Main entry point structure.

The main function correctly:

  • Creates a FlexibleArgumentParser for vLLM arguments.
  • Adds a positional model argument.
  • Uses make_arg_parser to augment with vLLM CLI arguments.
  • Launches the server with uvloop for async execution.

223-226: Verify new_determine_available_memory return type
The wrapper is annotated -> None but returns results from the original determine_available_memory. Confirm what the original method returns (e.g., memory stats) and update the annotation or adjust the return accordingly.
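
If the wrapped method does return a value (recent vLLM v1 workers return the available KV-cache memory in bytes; treat that as an assumption to verify), the wrapper can simply mirror it:

def new_determine_available_memory(self) -> int:
    # Keep torch.compile disabled while profiling memory, then forward the original return value.
    with disable_compilation(self.model_runner.model):
        return old_determine_available_memory(self)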


169-175: Check get_pp_group() behavior on non-distributed setups
get_pp_group() comes from vllm.distributed.parallel_state; verify that when torch.distributed isn’t initialized (e.g. single-GPU) calling .is_first_rank or .recv_tensor_dict won’t error.
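
One defensive pattern, sketched under the assumption that `is_first_rank` is the attribute exposed by the targeted vLLM version:

import torch.distributed as dist
from vllm.distributed.parallel_state import get_pp_group

def is_first_pp_rank() -> bool:
    # Single-GPU / non-distributed runs have no process group; treat them as the first PP rank.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return get_pp_group().is_first_rank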

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
examples/vllm_serve/vllm_serve_fakequant.py (4)

75-96: Make disable_compilation tolerant; avoid ValueError and support direct do_not_compile.

Current version raises on unknown types and will break when wrapping self.model_runner.model. Use a permissive target resolution and always restore state.

-@contextmanager
-def disable_compilation(model):
-    """Context manager to temporarily disable torch.compile"""
-    do_not_compile = True
-    if hasattr(model, "model"):
-        do_not_compile = model.model.do_not_compile
-        model.model.do_not_compile = True
-    elif hasattr(model, "language_model"):  # VLM requires this
-        do_not_compile = model.language_model.model.do_not_compile
-        model.language_model.model.do_not_compile = True
-    else:
-        raise ValueError("Model does not have a model or language_model attribute")
-
-    try:
-        yield
-    finally:
-        if hasattr(model, "model"):
-            model.model.do_not_compile = do_not_compile
-        elif hasattr(model, "language_model"):
-            model.language_model.model.do_not_compile = do_not_compile
+@contextmanager
+def disable_compilation(obj):
+    """Temporarily set do_not_compile on the underlying model if available."""
+    if hasattr(obj, "model"):
+        target = obj.model
+    elif hasattr(obj, "language_model") and hasattr(obj.language_model, "model"):
+        target = obj.language_model.model
+    else:
+        target = obj  # try object itself
+    old = getattr(target, "do_not_compile", False)
+    try:
+        setattr(target, "do_not_compile", True)
+        yield
+    finally:
+        setattr(target, "do_not_compile", old)

110-112: Fix pad_token condition (set when missing or equals unk).

Current logic is inverted; it overwrites valid pad tokens.

-    if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token is None or tokenizer.pad_token == getattr(tokenizer, "unk_token", "<unk>"):
+        tokenizer.pad_token = tokenizer.eos_token

177-181: Guard unknown quant_format with a clear error.

getattr without validation raises ambiguous AttributeError.

-    quant_cfg = getattr(mtq, quant_config["quant_format"])
+    fmt = quant_config["quant_format"]
+    try:
+        quant_cfg = getattr(mtq, fmt)
+    except AttributeError as e:
+        supported = [n for n in dir(mtq) if n.endswith("_CFG")]
+        raise ValueError(f"Unknown quant_format: {fmt}. Supported: {supported}") from e

187-213: *Amax override: update only _amax keys and verify mapping; use safe torch.load.

Counts can match while keys mismatch; broad updates risk corrupting non-amax params. Also prefer weights_only when available.

-    amax_file_path = quant_config["amax_file_path"]
-    if amax_file_path:
-        print(f"Loading amax values from {amax_file_path}")
-        saved_amax_dict = torch.load(amax_file_path, map_location=self.device)
-        current_state_dict = self.model.state_dict()
-
-        # Count amax keys in checkpoint and model
-        checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("amax")]
-        model_amax_keys = [key for key in current_state_dict if key.endswith("amax")]
-
-        checkpoint_amax_count = len(checkpoint_amax_keys)
-        model_amax_count = len(model_amax_keys)
-
-        # Ensure counts match
-        if checkpoint_amax_count != model_amax_count:
-            raise ValueError(
-                f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} "
-                f"amax keys but model has {model_amax_count} amax keys. "
-            )
-
-        # Update amax values
-        for key, value in saved_amax_dict.items():
-            if key in current_state_dict:
-                current_state_dict[key] = value.to(self.device)
-
-        self.model.load_state_dict(current_state_dict, strict=True)
+    amax_file_path = quant_config["amax_file_path"]
+    if amax_file_path:
+        print(f"Loading amax values from {amax_file_path}")
+        try:
+            saved_amax_dict = torch.load(amax_file_path, map_location="cpu", weights_only=True)
+        except TypeError:  # older torch
+            saved_amax_dict = torch.load(amax_file_path, map_location="cpu")
+        current_state_dict = self.model.state_dict()
+
+        # Filter amax-only keys
+        saved_amax = {k: v for k, v in saved_amax_dict.items() if k.endswith("_amax") or k.endswith("amax")}
+        model_amax_keys = {k for k in current_state_dict if k.endswith("_amax") or k.endswith("amax")}
+
+        missing_in_model = set(saved_amax.keys()) - model_amax_keys
+        missing_in_ckpt = model_amax_keys - set(saved_amax.keys())
+        if missing_in_model or missing_in_ckpt:
+            raise ValueError(
+                f"Amax key mismatch. "
+                f"Missing in model: {sorted(list(missing_in_model))[:10]} "
+                f"Missing in checkpoint: {sorted(list(missing_in_ckpt))[:10]}"
+            )
+
+        # Update amax values only
+        with torch.no_grad():
+            for key, value in saved_amax.items():
+                current_state_dict[key] = value.to(self.device)
+
+        self.model.load_state_dict(current_state_dict, strict=True)
🧹 Nitpick comments (2)
examples/vllm_serve/vllm_serve_fakequant.py (2)

113-116: Avoid mutating global quant_config inside runtime path.

Mutating quant_num_samples is global state and can leak across runs. Copy locally.

Example refactor:

cfg = dict(quant_config)
if cfg["amax_file_path"]:
    cfg["quant_num_samples"] = 1
# then use cfg[...] instead of quant_config[...] in this function

125-131: Limit tqdm to rank 0 to avoid multi-process spam.

Multiple workers will print overlapping bars. Show progress only on rank 0.

-        print("Calibrating model...")
-        for batch_idx, batch in tqdm(enumerate(calib_dataloader)):
+        print("Calibrating model...")
+        is_rank0 = (not torch.distributed.is_initialized()) or (torch.distributed.get_rank() == 0)
+        iterator = enumerate(calib_dataloader)
+        if is_rank0:
+            iterator = tqdm(iterator)
+        for batch_idx, batch in iterator:

Add import (outside this hunk):

import torch.distributed  # needed for the rank check above
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5f1ed83 and f8e1dd0.

📒 Files selected for processing (1)
  • examples/vllm_serve/vllm_serve_fakequant.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/vllm_serve/vllm_serve_fakequant.py (3)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (157-232)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
examples/vllm_serve/convert_amax_hf2vllm.py (1)
  • main (155-209)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Comment on lines +241 to +246
def main():
# Create parser that handles both quant and serve arguments
parser = FlexibleArgumentParser(description="vLLM model server with quantization support")
parser.add_argument("model", type=str, help="The path or name of the model to serve")
parser = make_arg_parser(parser)


⚠️ Potential issue | 🔴 Critical

Argparse bug: duplicate 'model' positional. Let make_arg_parser add it.

Adding 'model' before make_arg_parser will conflict since vLLM adds it too.

-def main():
-    # Create parser that handles both quant and serve arguments
-    parser = FlexibleArgumentParser(description="vLLM model server with quantization support")
-    parser.add_argument("model", type=str, help="The path or name of the model to serve")
-    parser = make_arg_parser(parser)
+def main():
+    # Create parser that handles both quant and serve arguments
+    parser = make_arg_parser(FlexibleArgumentParser(description="vLLM model server with quantization support"))
🤖 Prompt for AI Agents
In examples/vllm_serve/vllm_serve_fakequant.py around lines 241 to 246, the code
adds a positional "model" argument before calling make_arg_parser which also
adds "model", causing a duplicate argparse conflict; remove the explicit
parser.add_argument("model", ...) line (or guard it so it isn't added when using
make_arg_parser) and simply create the FlexibleArgumentParser and pass it to
make_arg_parser so only one "model" positional is defined.

@RalphMao RalphMao force-pushed the huizim/vllm_serve_update branch from f8e1dd0 to 5279884 Compare October 6, 2025 18:27

codecov bot commented Oct 6, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.79%. Comparing base (340eb7a) to head (cfd61b4).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #369   +/-   ##
=======================================
  Coverage   73.79%   73.79%           
=======================================
  Files         171      171           
  Lines       17591    17591           
=======================================
  Hits        12982    12982           
  Misses       4609     4609           


@RalphMao RalphMao force-pushed the huizim/vllm_serve_update branch from 5279884 to d02bc95 Compare October 7, 2025 17:07
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (5)
examples/vllm_serve/convert_amax_hf2vllm.py (1)

206-208: Guard os.makedirs against empty dirname.

When args.output is just a filename without a directory path, os.path.dirname(args.output) returns an empty string, causing os.makedirs to fail.

Apply this diff to guard the directory creation:

 print(f"Saving vLLM checkpoint to: {args.output}")
-os.makedirs(os.path.dirname(args.output), exist_ok=True)
+out_dir = os.path.dirname(args.output)
+if out_dir:
+    os.makedirs(out_dir, exist_ok=True)
 torch.save(vllm_state_dict, args.output)
examples/vllm_serve/vllm_serve_fakequant.py (4)

110-111: Fix inverted pad_token condition.

The condition at Line 110 is inverted. It sets pad_token when it is NOT "<unk>" or when it's None. The intent is to replace pad_token when it IS missing or IS "<unk>".

Apply this diff to correct the logic:

-if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
+if tokenizer.pad_token is None or tokenizer.pad_token == "<unk>":
     tokenizer.pad_token = tokenizer.eos_token

177-178: Add error handling for unknown quant_format.

getattr(mtq, quant_config["quant_format"]) at Line 177 will raise an unclear AttributeError if the format name is invalid.

Apply this diff to provide a clear error message:

-quant_cfg = getattr(mtq, quant_config["quant_format"])
+try:
+    quant_cfg = getattr(mtq, quant_config["quant_format"])
+except AttributeError as e:
+    raise ValueError(
+        f"Unknown quant_format: {quant_config['quant_format']}. "
+        f"Available formats: {[attr for attr in dir(mtq) if attr.endswith('_CFG')]}"
+    ) from e

187-213: Verify amax key mappings, not just counts.

The amax loading logic only checks that counts match (Lines 201-205), but doesn't verify that the same keys exist in both checkpoint and model. This could allow loading mismatched amax values to wrong layers. Additionally, Lines 208-210 update all keys from saved_amax_dict, not just amax keys.

Apply this diff to validate key mappings:

 amax_file_path = quant_config["amax_file_path"]
 if amax_file_path:
     print(f"Loading amax values from {amax_file_path}")
     saved_amax_dict = torch.load(amax_file_path, map_location=self.device)
     current_state_dict = self.model.state_dict()
 
-    # Count amax keys in checkpoint and model
-    checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("amax")]
-    model_amax_keys = [key for key in current_state_dict if key.endswith("amax")]
-
-    checkpoint_amax_count = len(checkpoint_amax_keys)
-    model_amax_count = len(model_amax_keys)
-
-    # Ensure counts match
-    if checkpoint_amax_count != model_amax_count:
-        raise ValueError(
-            f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} "
-            f"amax keys but model has {model_amax_count} amax keys. "
-        )
-
-    # Update amax values
-    for key, value in saved_amax_dict.items():
-        if key in current_state_dict:
-            current_state_dict[key] = value.to(self.device)
+    # Filter to amax keys only
+    saved_amax = {k: v for k, v in saved_amax_dict.items() if k.endswith("_amax") or k.endswith("amax")}
+    model_amax_keys = {k for k in current_state_dict if k.endswith("_amax") or k.endswith("amax")}
+
+    missing_in_model = set(saved_amax.keys()) - model_amax_keys
+    missing_in_checkpoint = model_amax_keys - set(saved_amax.keys())
+    
+    if missing_in_model or missing_in_checkpoint:
+        error_msg = []
+        if missing_in_model:
+            error_msg.append(f"Keys in checkpoint not found in model: {sorted(list(missing_in_model))[:5]}")
+        if missing_in_checkpoint:
+            error_msg.append(f"Keys in model not found in checkpoint: {sorted(list(missing_in_checkpoint))[:5]}")
+        raise ValueError("\n".join(error_msg))
+
+    # Update amax values only
+    with torch.no_grad():
+        for key, value in saved_amax.items():
+            current_state_dict[key] = value.to(self.device)
 
     self.model.load_state_dict(current_state_dict, strict=True)

241-246: Remove duplicate 'model' positional argument.

Line 244 adds a positional "model" argument, but Line 245's make_arg_parser() also adds a "model" positional, causing an argparse conflict.

Apply this diff to remove the duplicate:

 def main():
     # Create parser that handles both quant and serve arguments
-    parser = FlexibleArgumentParser(description="vLLM model server with quantization support")
-    parser.add_argument("model", type=str, help="The path or name of the model to serve")
-    parser = make_arg_parser(parser)
+    parser = make_arg_parser(
+        FlexibleArgumentParser(description="vLLM model server with quantization support")
+    )
 
     # Parse arguments
     args = parser.parse_args()
🧹 Nitpick comments (1)
examples/vllm_serve/vllm_serve_fakequant.py (1)

75-95: Make disable_compilation tolerant of unknown model types.

The context manager raises ValueError at Line 86 if the model lacks expected attributes. This prevents usage with custom or future model architectures.

A past review suggested making this tolerant by falling back to the object itself. Consider applying this pattern:

 @contextmanager
-def disable_compilation(model):
-    """Context manager to temporarily disable torch.compile"""
-    do_not_compile = True
-    if hasattr(model, "model"):
-        do_not_compile = model.model.do_not_compile
-        model.model.do_not_compile = True
-    elif hasattr(model, "language_model"):  # VLM requires this
-        do_not_compile = model.language_model.model.do_not_compile
-        model.language_model.model.do_not_compile = True
-    else:
-        raise ValueError("Model does not have a model or language_model attribute")
-
-    try:
-        yield
-    finally:
-        if hasattr(model, "model"):
-            model.model.do_not_compile = do_not_compile
-        elif hasattr(model, "language_model"):
-            model.language_model.model.do_not_compile = do_not_compile
+def disable_compilation(obj):
+    """Temporarily set do_not_compile on the underlying model if available."""
+    target = None
+    if hasattr(obj, "model"):
+        target = obj.model
+    elif hasattr(obj, "language_model") and hasattr(obj.language_model, "model"):
+        target = obj.language_model.model
+    else:
+        target = obj  # try the object itself
+    old = getattr(target, "do_not_compile", False)
+    try:
+        setattr(target, "do_not_compile", True)
+        yield
+    finally:
+        setattr(target, "do_not_compile", old)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5279884 and d02bc95.

📒 Files selected for processing (6)
  • .github/CODEOWNERS (1 hunks)
  • CHANGELOG.rst (1 hunks)
  • examples/vllm_serve/Dockerfile (1 hunks)
  • examples/vllm_serve/README.md (1 hunks)
  • examples/vllm_serve/convert_amax_hf2vllm.py (1 hunks)
  • examples/vllm_serve/vllm_serve_fakequant.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/vllm_serve/README.md
  • examples/vllm_serve/Dockerfile
🧰 Additional context used
🧬 Code graph analysis (1)
examples/vllm_serve/vllm_serve_fakequant.py (3)
examples/llm_autodeploy/api_server.py (1)
  • run_server (194-208)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (157-232)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Comment on lines +71 to +76
for merged_key, key_value_pairs in merge_groups.items():
if len(key_value_pairs) > 1:
# Take the maximum across all values for this merged key
values = [value for _, value in key_value_pairs]
merged_value = torch.stack(values).max(dim=0)[0]
vllm_state_dict[merged_key] = merged_value

⚠️ Potential issue | 🟠 Major

Add shape validation before stacking tensors.

torch.stack() at Line 75 requires all tensors to have identical shapes. If the amax values from different projections have mismatched shapes, a runtime error will occur.

Apply this diff to add shape validation:

 for merged_key, key_value_pairs in merge_groups.items():
     if len(key_value_pairs) > 1:
         # Take the maximum across all values for this merged key
         values = [value for _, value in key_value_pairs]
+        # Validate shapes before stacking
+        first_shape = values[0].shape
+        if not all(v.shape == first_shape for v in values):
+            mismatched = [(k, v.shape) for k, v in key_value_pairs]
+            raise ValueError(f"Cannot merge {merged_key}: shape mismatch {mismatched}")
         merged_value = torch.stack(values).max(dim=0)[0]
🤖 Prompt for AI Agents
In examples/vllm_serve/convert_amax_hf2vllm.py around lines 71 to 76, add
explicit shape validation before calling torch.stack: for each merged_key,
collect the shapes of tensors in values and ensure they are all identical; if
the shapes differ, raise a clear ValueError that includes the merged_key and the
list of tensor shapes so the caller can debug, otherwise proceed to
torch.stack(values).max(dim=0)[0] and assign to vllm_state_dict[merged_key].

/examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners
/examples/vlm_ptq @NVIDIA/modelopt-examples-vlm-codeowners
/examples/windows @NVIDIA/modelopt-windows-codeowners
/examples/windows @NVIDIA/modelopt-examples-llm_ptq-codeowners
Collaborator

This seems like a typo

Signed-off-by: Keval Morabia <[email protected]>
@kevalmorabia97 kevalmorabia97 enabled auto-merge (squash) October 7, 2025 17:50
@kevalmorabia97 kevalmorabia97 merged commit 1537885 into main Oct 7, 2025
27 checks passed
@kevalmorabia97 kevalmorabia97 deleted the huizim/vllm_serve_update branch October 7, 2025 18:17